# -*- coding: utf-8 -*-
"""
Created on Sat Jan 30 13:34:11 2021

@author: wzhangcd
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import math
import numpy as np
import pandas as pd

import sys
sys.path.append('./lib/')
from pkl_process import *
from utils import load_graphdata_channel_my, compute_val_loss_sttn, masked_mape_np, re_normalization, max_min_normalization, re_max_min_normalization

from time import time
import shutil
import argparse
import configparser
from tensorboardX import SummaryWriter
import os

from ST_Transformer_new import STTransformer # STTN model with linear layer to get positional embedding
from ST_Transformer_new_sinembedding import STTransformer_sinembedding #STTN model with sin()/cos() to get positional embedding, the same as "Attention is all your need"

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error


def predict_and_save_results_my(net, data_loader, data_target_tensor, global_step, _mean, _std, params_path, type):
    '''
    Run ``net`` over ``data_loader``, save the raw results and report error metrics.

    The stacked inputs, predictions and targets are written with ``np.savez`` to
    ``<params_path>/output_epoch_<global_step>_<type>`` (NumPy appends ``.npz``).
    MAE, RMSE and MAPE are printed for every index along the last axis of the
    prediction and once more over all values together.

    :param net: nn.Module, called as ``net(encoder_inputs.permute(0, 2, 1, 3))``
    :param data_loader: sized iterable yielding ``(encoder_inputs, labels)`` batches
    :param data_target_tensor: tensor holding the ground truth for the whole loader
    :param global_step: int, epoch/step identifier used in the file name and log lines
    :param _mean: mean handed to ``re_normalization`` for the saved inputs
                  (documented as (1, 1, 3, 1) by the original author -- TODO confirm)
    :param _std: std handed to ``re_normalization`` for the saved inputs
                 (documented as (1, 1, 3, 1) by the original author -- TODO confirm)
    :param params_path: directory in which the results file is saved
    :param type: string tag appended to the results file name, e.g. 'test'
                 (shadows the builtin; the name is kept for caller compatibility)
    :return: flat list ``[mae, rmse, mape, ...]`` holding one triple per prediction
             index followed by the overall triple
    :raises ValueError: if predictions and targets differ in their first dimension
    '''
    net.eval()  # same as net.train(False): dropout layers run in test mode

    with torch.no_grad():

        data_target_tensor = data_target_tensor.cpu().numpy()

        loader_length = len(data_loader)  # number of batches

        prediction_batches = []  # outputs of every batch
        input_batches = []  # inputs of every batch

        for batch_index, (encoder_inputs, _labels) in enumerate(data_loader):

            # keep only index 0 along the third axis of the encoder input
            input_batches.append(encoder_inputs[:, :, 0:1].cpu().numpy())

            outputs = net(encoder_inputs.permute(0, 2, 1, 3))
            prediction_batches.append(outputs.detach().cpu().numpy())

            if batch_index % 100 == 0:
                print('predicting data set batch %s / %s' % (batch_index + 1, loader_length))

        # undo the normalization so the saved inputs are on the original scale
        inputs = re_normalization(np.concatenate(input_batches, 0), _mean, _std)

        prediction = np.concatenate(prediction_batches, 0)

        print('input:', inputs.shape)
        print('prediction:', prediction.shape)
        print('data_target_tensor:', data_target_tensor.shape)
        output_filename = os.path.join(params_path, 'output_epoch_%s_%s' % (global_step, type))
        np.savez(output_filename, input=inputs, prediction=prediction, data_target_tensor=data_target_tensor)

        # The sample counts must agree before any metric is meaningful; checked
        # once here rather than on every loop iteration.
        if data_target_tensor.shape[0] != prediction.shape[0]:
            raise ValueError(
                'number of targets (%s) does not match number of predictions (%s)'
                % (data_target_tensor.shape[0], prediction.shape[0]))

        # compute the errors
        excel_list = []
        prediction_length = prediction.shape[2]

        for i in range(prediction_length):
            print('current epoch: %s, predict %s points' % (global_step, i))
            target_i = data_target_tensor[:, :, i]
            prediction_i = prediction[:, :, i]
            mae = mean_absolute_error(target_i, prediction_i)
            rmse = mean_squared_error(target_i, prediction_i) ** 0.5
            mape = masked_mape_np(target_i, prediction_i, 0)
            print('MAE: %.2f' % (mae))
            print('RMSE: %.2f' % (rmse))
            print('MAPE: %.2f' % (mape))
            excel_list.extend([mae, rmse, mape])

        # overall results across every prediction index
        flat_target = data_target_tensor.reshape(-1, 1)
        flat_prediction = prediction.reshape(-1, 1)
        mae = mean_absolute_error(flat_target, flat_prediction)
        rmse = mean_squared_error(flat_target, flat_prediction) ** 0.5
        mape = masked_mape_np(flat_target, flat_prediction, 0)
        print('all MAE: %.2f' % (mae))
        print('all RMSE: %.2f' % (rmse))
        print('all MAPE: %.2f' % (mape))
        excel_list.extend([mae, rmse, mape])
        print(excel_list)

    return excel_list

def predict_main(params_filename, global_step, data_loader, data_target_tensor, _mean, _std, type, model=None):
    '''
    Load the weights saved for ``global_step`` and evaluate the network on ``data_loader``.

    :param params_filename: despite its name, the directory that holds the
                            ``epoch_<global_step>.params`` checkpoints; the
                            prediction results are written to the same directory
    :param global_step: int, selects the checkpoint ``epoch_<global_step>.params``
    :param data_loader: torch.utils.data.utils.DataLoader
    :param data_target_tensor: tensor
    :param _mean: (1, 1, 3, 1) according to the original author -- TODO confirm
    :param _std: (1, 1, 3, 1) according to the original author -- TODO confirm
    :param type: string tag forwarded to ``predict_and_save_results_my``
                 (shadows the builtin; the name is kept for caller compatibility)
    :param model: optional nn.Module to evaluate; when omitted the module-level
                  ``net`` created by the script entry point is used
    :return: whatever ``predict_and_save_results_my`` returns
    '''
    if model is None:
        # Fall back to the network built under ``if __name__ == '__main__'``.
        model = net

    # The first argument used to be overwritten using the global ``params_path``;
    # it is now honoured as the checkpoint directory, which is what callers pass.
    params_dir = params_filename
    checkpoint_file = os.path.join(params_dir, 'epoch_%s.params' % global_step)
    print('load weight from:', checkpoint_file)

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can also
    # be restored on a CPU-only host; load_state_dict then copies the values
    # into the model's own parameters.
    model.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))

    return predict_and_save_results_my(model, data_loader, data_target_tensor, global_step, _mean, _std, params_dir, type)



if __name__ == '__main__':

    # Epoch whose checkpoint is evaluated (the best epoch seen during training)
    best_epoch = 0

    # ---- Paths and data settings: keep identical to train_my.py ----
    # Directory in which the network parameters were saved
    params_path = './Experiment/PEMS25_embed_size64'
    print('params_path:', params_path)
    # Data file generated by prepareData.py
    filename = './PEMSD7/V_25_r1_d0_w0_astcgn.npz'
    # Must match the values used by prepareData.py
    num_of_hours = 1
    num_of_days = 0
    num_of_weeks = 0

    # ---- Training hyperparameters ----
    if torch.cuda.is_available():
        DEVICE = "cuda:0"
    else:
        DEVICE = "cpu"
    device = DEVICE
    batch_size = 32
    learning_rate = 0.01
    epochs = 200

    # ---- Data loaders ----
    (train_loader, train_target_tensor,
     val_loader, val_target_tensor,
     test_loader, test_target_tensor,
     _mean, _std) = load_graphdata_channel_my(
        filename, num_of_hours, num_of_days, num_of_weeks, DEVICE, batch_size)

    # ---- Adjacency matrix ----
    adj_mx = np.array(pd.read_csv('./PEMSD7/W_25.csv', header=None))
    A = torch.Tensor(adj_mx)

    # ---- Model hyperparameters ----
    in_channels = 1  # number of input channels
    embed_size = 64  # dimension of the hidden embedding features
    time_num = 288
    num_layers = 3  # number of ST blocks
    T_dim = 12  # input length, should be the same as in prepareData.py
    output_T_dim = 12  # expected output length
    heads = 2  # number of heads in MultiHeadAttention
    cheb_K = 2  # order of the Chebyshev polynomials (Eq 2)
    # Feed-forward width: embed_size --> embed_size * forward_expansion --> embed_size
    forward_expansion = 4

    # ---- Build the network and move it to the chosen device ----
    net = STTransformer(
        A, in_channels, embed_size, time_num, num_layers,
        T_dim, output_T_dim, heads, cheb_K, forward_expansion)
    net.to(device)

    # ---- Evaluate the chosen checkpoint on the test split ----
    predict_main(params_path, best_epoch, test_loader, test_target_tensor, _mean, _std, 'test')